import random
import numpy as np
import torch
torch.set_float32_matmul_precision('high')
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import get_peft_model, PeftModel
from tqdm import tqdm
import logging
import csv
from datasets import load_from_disk
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import sys, os
import argparse
from functools import partial
import pickle

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.utils import load_checkpoint, print_color 
from data_prep.pre_processing_data import load_and_process_dataset_from_name 
from simple_dpo.train_MO_PO import estimate_validation_loss, collate_fn

# Set up logger
# The log file name embeds the process id, presumably so that several
# evaluation processes started from the same directory do not share a log.
# filemode='w' truncates an existing file of the same name on start-up.
pid = os.getpid()
log_filename = f'my_eval_log_{pid}.log'
logging.basicConfig(filename=log_filename, filemode='w', level=logging.DEBUG, 
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Module-wide target device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and torch (CPU, plus all CUDA devices
    when CUDA is available), then switches cuDNN to deterministic mode with
    benchmarking turned off.

    Parameters:
        seed: Integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    # Disable the cuDNN auto-tuner first, then request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def generate_output(model, inputs, tokenizer, max_length=512, max_new_tokens=256):
    """Sample one continuation per prompt and return only the newly generated text.

    The model is switched to eval mode and generation runs without gradient
    tracking. Decoding is stochastic (temperature + nucleus sampling), so the
    result depends on the current torch RNG state.

    Parameters:
        model: Causal language model exposing ``eval()`` and ``generate()``.
        inputs: Mapping holding the batched ``'input_ids'`` and
            ``'attention_mask'`` tensors of the prompts.
        tokenizer: Tokenizer used to decode the generated ids; its EOS token
            id doubles as the padding id during generation.
        max_length: Unused. Kept so existing callers that pass it keep working.
        max_new_tokens: Maximum number of tokens generated after the prompt.

    Returns:
        A list with one decoded string per generated sequence, containing only
        the tokens that follow the prompt (special tokens removed).
    """
    model.eval()

    # Move both tensors to the module-level device; a no-op when the collate
    # function has already placed them there.
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Every row shares the same (padded) prompt width, so everything past this
    # column is newly generated text.
    prompt_len = input_ids.shape[1]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,  # stochastic sampling, NOT greedy decoding
            temperature=0.7,
            top_p=0.9,  # nucleus sampling
            top_k=0,  # 0 turns top-k filtering off so only top-p applies
            num_return_sequences=1,
            repetition_penalty=1.2,  # discourage repetition
        )

    # Slice off the prompt part of each output to keep only the new tokens.
    return [
        tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
        for sequence in outputs
    ]


def parse_args():
    """Parse the command-line arguments of the evaluation script.

    Returns:
        An ``argparse.Namespace`` with:
            checkpoint (str): required, path of the checkpoint to evaluate.
            model (str): required, name of the base model.
            use_lora (bool): True when ``--use_lora`` is given.
            quantization (bool): True when ``--quantization`` is given.
    """
    parser = argparse.ArgumentParser(description="Evaluation Script")
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path of the fine-tuned checkpoint to evaluate')
    parser.add_argument('--model', type=str, required=True,
                        help='Specify the model name')
    # Both switches are plain flags: present means True, absent means False.
    parser.add_argument('--use_lora', action="store_true",
                        help='Load the checkpoint as a LoRA/PEFT adapter on top of the base model')
    parser.add_argument('--quantization', action="store_true",
                        help='Load the base model with 8-bit quantization')

    return parser.parse_args()


def compare_model_parameters(model_a, model_b):
    """Report whether two models hold element-wise identical parameters.

    Parameters are paired positionally, in ``named_parameters()`` order. The
    comparison stops at the first mismatch; every matching parameter seen
    before that point is logged at INFO level.

    Parameters:
        model_a: First model (anything exposing ``named_parameters()``).
        model_b: Second model.

    Returns:
        True when both models have the same number of parameters and every
        positional pair is equal according to ``torch.equal``; False otherwise.
    """
    named_a = list(model_a.named_parameters())
    named_b = list(model_b.named_parameters())

    # A different parameter count can never match; bail out before pairing.
    if len(named_a) != len(named_b):
        logging.info("Models have a different number of parameters.")
        return False

    for (name_a, tensor_a), (name_b, tensor_b) in zip(named_a, named_b):
        if torch.equal(tensor_a.data, tensor_b.data):
            logging.info(f"Parameters are the same: {name_a}")
            continue
        logging.info(f"Parameters differ: {name_a} != {name_b}")
        return False

    return True

def _write_csv_header(path):
    """Create (or truncate) ``path`` and write the ``prompt,response`` header row."""
    with open(path, mode='w', newline='', encoding='utf-8') as f:
        csv.writer(f).writerow(["prompt", "response"])


def _append_csv_rows(path, raw_prompts, responses):
    """Append one ``[prompt, response]`` row per pair, skipping empty/None prompts."""
    with open(path, mode='a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        for raw_prompt, response in zip(raw_prompts, responses):
            if raw_prompt:
                writer.writerow([raw_prompt, response])


def _reward_logits(rank_model, rewardtokenizer, texts, max_length):
    """Tokenize ``texts`` and return the reward model's logits as a NumPy array."""
    encoded = rewardtokenizer.batch_encode_plus(texts,
                                                return_tensors='pt',
                                                add_special_tokens=True,
                                                padding=True,
                                                truncation=True,
                                                max_length=max_length,
                                                return_attention_mask=True)
    encoded = {k: v.to(device) for k, v in encoded.items()}
    with torch.no_grad():
        return rank_model(**encoded).logits.cpu().detach().numpy()


@torch.no_grad()
def evaluate(peft_model, rank_model, tokenizer, rewardtokenizer, testloader, csv_file, max_length=512):
    """Generate responses with and without the adapter and score both with a reward model.

    For every batch, a response is sampled from the adapted model and then
    from the same model with the adapter disabled. Both responses are wrapped
    in the rating template and scored by ``rank_model``. The responses are
    appended to two CSV files after each batch, so partial results survive an
    interrupted run: ``csv_file`` receives the adapted model's responses and
    ``<csv_file stem>_org<ext>`` the base model's.

    Parameters:
        peft_model: PEFT-wrapped causal LM (must support ``disable_adapter()``).
        rank_model: Sequence-classification reward model returning ``.logits``.
        tokenizer: Tokenizer of ``peft_model``.
        rewardtokenizer: Tokenizer of ``rank_model``.
        testloader: Iterable of batches with ``'prompts'``, ``'input_ids'``
            and ``'attention_mask'`` entries.
        csv_file: Output path for the adapted model's responses. Both CSV
            files are truncated at the start of the call.
        max_length: Truncation length for the reward model inputs.

    Returns:
        A 6-tuple of lists, in this order: base-model scores from logit
        columns 8, 9 and 6, then adapted-model scores from logit columns
        8, 9 and 6.
        NOTE(review): going by the variable names these columns are honesty,
        helpfulness and instruction following - confirm against the reward
        model's output layout.
    """
    # Remember the caller's mode so it can be restored instead of forcing
    # train mode on a model that was being used for inference only.
    was_training = peft_model.training
    peft_model.eval()

    ori_scores_hones = []
    ori_scores_help = []
    ori_scores_instuction = []
    tuned_scores_hones = []
    tuned_scores_help = []
    tuned_instuction = []

    # Start both CSV files from scratch with just the header row.
    base, ext = os.path.splitext(csv_file)
    original_csv_file = base + "_org" + ext
    _write_csv_header(csv_file)
    _write_csv_header(original_csv_file)

    input_template = "[INST] You must read the following conversation carefully and rate the assistant's response from score 0-100 in these aspects: helpfulness, correctness, coherence, honesty, complexity, verbosity\n\nUser: {prompt}\n\nAssistant: {response} [/INST]"

    try:
        for batch in tqdm(testloader, desc="Evaluating"):
            prompts = batch['prompts']
            inputs = {
                'prompts': prompts,
                'input_ids': batch['input_ids'],
                'attention_mask': batch['attention_mask']
            }

            # Adapted model first, then the base model; sampling consumes RNG
            # state, so this order is part of the reproducible behaviour.
            output_peft = generate_output(peft_model, inputs, tokenizer)
            with peft_model.disable_adapter():
                output_original = generate_output(peft_model, inputs, tokenizer)

            # Strip the chat wrapper once per prompt and reuse it below.
            raw_prompts = [extract_raw_prompt(prompt) for prompt in prompts]
            # A prompt that does not follow the expected wrapper yields None;
            # score it with its original text rather than the string "None".
            reward_prompts = [
                raw if raw is not None else prompt
                for raw, prompt in zip(raw_prompts, prompts)
            ]

            ori_outputs = [
                input_template.format(prompt=prompt, response=output)
                for prompt, output in zip(reward_prompts, output_original)
            ]
            peft_outputs = [
                input_template.format(prompt=prompt, response=output)
                for prompt, output in zip(reward_prompts, output_peft)
            ]

            ori_score = _reward_logits(rank_model, rewardtokenizer, ori_outputs, max_length)
            tuned_score = _reward_logits(rank_model, rewardtokenizer, peft_outputs, max_length)

            ori_scores_instuction.extend(ori_score[:, 6])
            ori_scores_hones.extend(ori_score[:, 8])
            ori_scores_help.extend(ori_score[:, 9])

            tuned_instuction.extend(tuned_score[:, 6])
            tuned_scores_hones.extend(tuned_score[:, 8])
            tuned_scores_help.extend(tuned_score[:, 9])

            # Write to files (rows whose prompt could not be extracted are skipped).
            _append_csv_rows(csv_file, raw_prompts, output_peft)
            _append_csv_rows(original_csv_file, raw_prompts, output_original)
    finally:
        peft_model.train(was_training)

    return ori_scores_hones, ori_scores_help, ori_scores_instuction, tuned_scores_hones, tuned_scores_help, tuned_instuction

def extract_raw_prompt(formatted_prompt):
    """Return the user text wrapped in a ``Human``/``Assistant`` formatted prompt.

    The prompt must begin with ``"\\n\\nHuman:\\n"`` and contain
    ``"\\n\\nAssistant:\\n"``. The text between the prefix and the first
    occurrence of the suffix is returned with surrounding whitespace removed.

    Parameters:
        formatted_prompt: The fully formatted prompt string.

    Returns:
        The stripped user text, or None when the prompt does not follow the
        expected format.
    """
    prefix = "\n\nHuman:\n"
    suffix = "\n\nAssistant:\n"

    if not formatted_prompt.startswith(prefix):
        return None  # malformed: wrapper prefix missing

    # Search the whole string (not just the part after the prefix) for the
    # first suffix occurrence.
    suffix_at = formatted_prompt.find(suffix)
    if suffix_at == -1:
        return None  # malformed: wrapper suffix missing

    return formatted_prompt[len(prefix):suffix_at].strip()

def collate_fn_prompt(batch, tokenizer, max_length, device):
    """
      Collate function that batches prompt-only samples.

      Parameters:
          batch: List of samples; each sample provides a 'prompt' entry.
          tokenizer: Tokenizer used to encode the prompts.
          max_length: Maximum sequence length used for truncation.
          device: Device the encoded tensors are moved to.

      Returns:
          A dictionary with the raw prompt strings under 'prompts' and the
          encoded 'input_ids' and 'attention_mask' tensors on `device`.
      """
    prompt_texts = []
    for sample in batch:
        prompt_texts.append(sample['prompt'])

    # Encode all prompts together so they are padded to a common length.
    encoded = tokenizer.batch_encode_plus(
        prompt_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
        return_attention_mask=True
    )

    return {
        'prompts': prompt_texts,
        'input_ids': encoded['input_ids'].to(device),
        'attention_mask': encoded['attention_mask'].to(device),
    }

def get_reward_models(reward_name):
    """Load a sequence-classification reward model together with its tokenizer.

    The model is loaded with ``trust_remote_code=True`` and moved to the
    module-level ``device``. When the tokenizer defines no padding token, its
    EOS token is used instead, and the model config's ``pad_token_id`` is set
    to the tokenizer's padding id.

    Parameters:
        reward_name: Model name or path passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        reward_name,
        trust_remote_code=True,
    )
    reward_model = reward_model.to(device)

    reward_tokenizer = AutoTokenizer.from_pretrained(reward_name)
    # Batched scoring needs a padding token; fall back to EOS when none is defined.
    if reward_tokenizer.pad_token is None:
        reward_tokenizer.pad_token = reward_tokenizer.eos_token

    # Keep the model's padding id in step with the tokenizer's.
    reward_model.config.pad_token_id = reward_tokenizer.pad_token_id
    return reward_model, reward_tokenizer

def _output_stem(checkpoint):
    """Derive the output file stem from a checkpoint path.

    The path components after the first are joined with "_" and everything
    from the first ".pt" onwards is dropped. A path with a single component
    falls back to that component so the stem is never empty.
    """
    def _cut_at_pt(name):
        index = name.find('.pt')
        return name if index == -1 else name[:index]  # keep everything before .pt

    stem = _cut_at_pt("_".join(checkpoint.split("/")[1:]))
    if not stem:
        stem = _cut_at_pt(os.path.basename(os.path.normpath(checkpoint)))
    if not stem:
        raise ValueError(f"Cannot derive an output file name from checkpoint path {checkpoint!r}")
    return stem


def main():
    """Evaluate a fine-tuned checkpoint on the UltraFeedback test split.

    Loads the base model (optionally 8-bit quantized) and, with ``--use_lora``,
    the PEFT adapter from the checkpoint path. It then computes the validation
    loss and perplexity statistics for every preference dimension and pickles
    the results to ``eval/<stem>``, where ``<stem>`` is derived from the
    checkpoint path.

    The reward-model comparison is currently disabled (kept as commented-out
    code), so the "orignal" and "tuned" entries of the result are empty dicts.

    Raises:
        ValueError: If no output file name can be derived from the checkpoint path.
    """
    args = parse_args()
    model_name = args.model
    checkpoint = args.checkpoint
    use_lora = args.use_lora
    quantization = args.quantization

    test_config = {
        "batch_size": 20,
        "max_length": 512,
        # NOTE(review): hard-coded rather than taken from --use_lora; confirm
        # what estimate_validation_loss expects before wiring the flag through.
        "use_lora": True,
        "val_data_size": 1000,
        "val_batch_size": 20,
        "beta": 0.1,
        "loss_type": "sigmoid",
    }

    seed = 0
    set_seed(seed)
    checkpoint_path = checkpoint

    # Resolve the output names first so an unusable checkpoint path fails
    # before any model is loaded, not after the whole evaluation has run.
    filename = _output_stem(checkpoint)
    csv_file = f"eval/response/{filename}.csv"
    os.makedirs(os.path.dirname(csv_file), exist_ok=True)  # also creates eval/

    # Should match the setup used during tuning.
    if quantization:
        # Load the 8-bit quantized model; the quantization settings are passed
        # through BitsAndBytesConfig.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    if use_lora:
        model = PeftModel.from_pretrained(model, checkpoint_path)
    else:
        # Loading a non-LoRA checkpoint is not implemented; say so instead of
        # silently evaluating the untouched base model.
        logging.warning("--use_lora not set: checkpoint %s is NOT loaded, "
                        "evaluating the base model only", checkpoint_path)

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    tokenizer.pad_token = tokenizer.eos_token
    logging.info(f"checkpiont path is {checkpoint_path}")

    # Filter out some columns
    exclude_cols = ["truthfulness", "overall"]
    data_name = "openbmb/UltraFeedback"

    full_test_data = load_and_process_dataset_from_name(data_name, split='test', seed=seed, removed_dimensions = exclude_cols)
    # Any dimension works here, since only the prompt is needed.
    test_data = full_test_data["honesty"]

    # Keep only the first occurrence of every prompt.
    seen = set()

    def _is_first_occurrence(item):
        prompt = item['prompt']
        if prompt in seen:
            return False
        seen.add(prompt)
        return True

    test_data = test_data.filter(_is_first_occurrence)

    # Only consumed by the disabled reward evaluation below.
    test_data_loader = torch.utils.data.DataLoader(test_data, batch_size = test_config["batch_size"],
                                                   shuffle=False,
                                                   collate_fn=partial(collate_fn_prompt, tokenizer=tokenizer,
                                                                      max_length=test_config["max_length"], device=device))

    reward_name_list = ["RLHFlow/RewardModel-Mistral-7B-for-DPA-v1"]

    original_lists = {}
    tuned_lists = {}

    # Reward computation, currently disabled:
    # for reward in reward_name_list:
    #     rank_model, rewardtokenizer = get_reward_models(reward)
    #     ori_scores_hones, ori_scores_help, ori_scores_instuction, tuned_scores_hones, tuned_scores_help, tuned_instuction = evaluate(model, rank_model, tokenizer, rewardtokenizer, test_data_loader, csv_file, max_length=512)
    #     original_lists["honesty"] = ori_scores_hones
    #     original_lists["helpful"] = ori_scores_help
    #     original_lists["instruction"] = ori_scores_instuction
    #     tuned_lists["honesty"] = tuned_scores_hones
    #     tuned_lists["helpful"] = tuned_scores_help
    #     tuned_lists["instruction"] = tuned_instuction
    results = {"orignal": original_lists, "tuned": tuned_lists}

    # Compute the validation loss.
    torch.cuda.empty_cache()
    # NOTE(review): same call and arguments as full_test_data above; the two
    # could probably share one load if the loader is deterministic - confirm.
    validation_dataset = load_and_process_dataset_from_name(data_name, split='test', seed=seed, removed_dimensions = exclude_cols)
    dimensions = list(validation_dataset.keys())
    dict_val_dataloaders = {}
    for dimension in dimensions:
        # No shuffling for validation.
        dict_val_dataloaders[dimension] = torch.utils.data.DataLoader(
            validation_dataset[dimension],
            batch_size=test_config["val_batch_size"],
            shuffle=False,
            collate_fn=partial(collate_fn, tokenizer=tokenizer,
                               max_length=test_config["max_length"],
                               device=device))

    print(dimensions)
    mean_loss, median_loss, val_reward_accuracy, val_reward_margins, val_perp, val_perp_neg, model_per, model_per_neg = estimate_validation_loss(
        model, None, dict_val_dataloaders, dimensions, test_config, tokenizer, device)

    results["val_losses"] = mean_loss
    results["val_losses_median"] = median_loss
    results["model_per"] = model_per
    results["val_perp"] = val_perp
    logging.info(f"Model Perplexity: {results['model_per']}")
    logging.info(f"Model Perplexity ratio: {results['val_perp']}")
    logging.info(f"Model Perplexity (reject): {model_per_neg}")
    logging.info(f"Model Perplexity ratio (rejected): {val_perp_neg}")

    with open(f"eval/{filename}", 'wb') as file:
        pickle.dump(results, file)


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()

# Example usage:
# python simple_dpo/evaluation.py --model "meta-llama/Llama-3.2-1B-Instruct" --checkpoint "1743730838_pid_869442_model_meta-llamaLlama-3.2-1B-Instruct_pref_0.4_0.6/epoch_1_step_4950_valloss_0.731-0.7619_trailoss0.473-0.7957.ptpeft_checkpoint" --use_lora